'''
Author: SlytherinGe
LastEditTime: 2021-03-29 21:39:30
'''
import torch


def expand_conv_input(weight, extra_channels=1, scale=1000.0):
    """Return ``weight`` with extra, near-zero input channels appended.

    Adapts a pretrained first-convolution weight so it accepts more input
    channels. The new channels are filled with ``torch.randn`` noise divided
    by ``scale``; the existing channels are kept unchanged.

    Args:
        weight: Convolution weight tensor. Dim 0 is treated as the output
            channels, dim 1 as the input channels, and the remaining dims as
            the kernel size (the PyTorch ``(out, in, *kernel)`` layout).
        extra_channels: Number of input channels to append. Defaults to 1.
        scale: Divisor applied to the random initialisation. Defaults to
            1000.0.

    Returns:
        A new tensor of shape ``(out, in + extra_channels, *kernel)``.
    """
    out_channels = weight.shape[0]
    kernel_size = weight.shape[2:]
    # Create the noise with the same dtype/device as the checkpoint weight so
    # the patched tensor keeps the checkpoint's dtype and stays on its device.
    extra = torch.randn(out_channels, extra_channels, *kernel_size,
                        dtype=weight.dtype, device=weight.device) / scale
    return torch.cat((weight, extra), dim=1)


if __name__ == '__main__':
    # Inspect the first conv layer of the backed-up ResNet-50 checkpoint.
    model_file = torch.load('/home/gejunyao/.cache/torch/hub/checkpoints/backup/resnet50-19c8e357.pth')
    conv1 = model_file['conv1.weight']
    print(conv1.shape)

    # Recipes for writing a checkpoint whose first conv takes one extra input
    # channel. Load the matching source checkpoint above, then uncomment the
    # recipe. Each recipe saves into .../checkpoints/ (not .../backup/).

    # resnet50
    # model_file['conv1.weight'] = expand_conv_input(model_file['conv1.weight'])
    # print(model_file['conv1.weight'].shape)
    # torch.save(model_file, '/home/gejunyao/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth')

    # darknet53 (weights are nested under 'state_dict')
    # state_dict = model_file['state_dict']
    # state_dict['conv1.conv.weight'] = expand_conv_input(state_dict['conv1.conv.weight'])
    # print(state_dict['conv1.conv.weight'].shape)
    # torch.save(model_file, '/home/gejunyao/.cache/torch/hub/checkpoints/darknet53-a628ea1b.pth')

    # vgg16 -- NOTE: the first conv is 'features.0.weight'; the expanded
    # tensor must be written back to that same key, not to 'conv1.weight'.
    # model_file['features.0.weight'] = expand_conv_input(model_file['features.0.weight'])
    # print(model_file['features.0.weight'].shape)
    # torch.save(model_file, '/home/gejunyao/.cache/torch/hub/checkpoints/vgg16_caffe-292e1171.pth')